import functools
import gym
import numpy as np
import pandas as pd
from epicare import *
from gym.envs.registration import register


# -----------------------------
# Offline Data Generation Section
# -----------------------------
def collect_trajectory(env, max_individuals=1000, seed=42):
    """Roll out a uniformly random policy and log every transition.

    Each individual (episode) ``pid`` is reset with ``seed + pid`` and all
    actions are drawn from one ``np.random.RandomState(seed)``, so the same
    arguments reproduce the same dataset.

    Both gym step/reset conventions are accepted. The classic API has
    ``reset`` return the observation and ``step`` return a 4-tuple. The
    gym>=0.26 API has ``reset`` return ``(obs, info)`` and ``step`` return a
    5-tuple with separate ``terminated``/``truncated`` flags; for it,
    ``done`` is recorded as ``terminated or truncated``.

    Args:
        env: Environment with a discrete ``action_space`` exposing ``.n``.
            The info dict returned by ``step`` must contain ``"delta"``.
        max_individuals: Number of episodes to roll out.
        seed: Base seed for the action RNG and the per-episode resets.

    Returns:
        pandas.DataFrame with one row per step and columns ``pid``,
        ``stage`` (step index within the episode), ``state`` (observation
        before the action, as a list), ``action``, ``reward``, ``delta``
        and ``done``.
    """
    rng = np.random.RandomState(seed)  # one RNG -> reproducible action sequence
    data = []
    for pid in range(max_individuals):
        # Reseed per individual so each episode is reproducible on its own.
        reset_out = env.reset(seed=seed + pid)
        # gym>=0.26 returns (obs, info); older versions return obs alone.
        if (
            isinstance(reset_out, tuple)
            and len(reset_out) == 2
            and isinstance(reset_out[1], dict)
        ):
            state = reset_out[0]
        else:
            state = reset_out
        done = False
        stage = 0
        while not done:
            action = rng.randint(env.action_space.n)  # uniform random action
            step_out = env.step(action)
            if len(step_out) == 5:  # gym>=0.26: (obs, r, terminated, truncated, info)
                next_state, reward, terminated, truncated, info = step_out
                done = terminated or truncated
            else:  # classic gym: (obs, r, done, info)
                next_state, reward, done, info = step_out
            data.append(
                {
                    "pid": pid,
                    "stage": stage,
                    "state": np.asarray(state).tolist(),
                    "action": action,
                    "reward": reward,
                    "delta": info["delta"],
                    "done": done,
                }
            )
            state = next_state
            stage += 1
    return pd.DataFrame(data)


def normalize_rewards(df, eps=1e-6, *, by="stage"):
    """Min-max normalize ``df["reward"]`` within groups so every reward is > 0.

    Each reward ``r`` becomes ``(r - min) / (max - min + eps) + eps``, where
    ``min``/``max`` are taken over the rows that share the same ``by`` key.
    The result lies in ``[eps, 1 + eps)``.

    Args:
        df: DataFrame with a numeric ``reward`` column and the ``by`` column(s).
        eps: Small constant that avoids division by zero and keeps the
            smallest reward in each group strictly positive.
        by: Column name or list of column names to group by. Defaults to
            ``"stage"``, which normalizes each stage across all individuals.

    Returns:
        The same DataFrame, modified in place, with ``reward`` replaced.
    """
    # Default is not ["pid", "stage"]. In trajectories from collect_trajectory,
    # each (pid, stage) pair holds exactly one row. Every such group would
    # then have min == max, and every reward would collapse to eps.
    grouped = df.groupby(by)["reward"]
    lo = grouped.transform("min")
    hi = grouped.transform("max")
    df["reward"] = (df["reward"] - lo) / (hi - lo + eps) + eps
    return df


if __name__ == "__main__":
    seed = 123
    # Set to True to min-max normalize rewards before writing the CSV.
    normalize = False

    env = gym.make("EpiCare-v0")
    try:
        env.reset(seed=seed)  # seed the environment's internal RNG
        df = collect_trajectory(env, max_individuals=1000, seed=seed)
    finally:
        env.close()  # release environment resources even if collection fails

    if normalize:
        df = normalize_rewards(df)

    df.to_csv("offline_dataset.csv", index=False)
    # Report the actual normalization state instead of always claiming it.
    status = "rewards normalized" if normalize else "rewards not normalized"
    print(f"offline_dataset.csv generated successfully (seed fixed, {status})!")
    print(df.head())